# ECAM Paper Experiment Design Guide

This document provides a detailed guide for designing and executing experiments for the paper "ECAM: Enhancing Causal Reasoning in Foundation Models with Endogenous Causal Attention Mechanism", including experimental design, data collection methods, required datasets, comparison models, and visualization charts.

## 1. Overall Experimental Framework

The experimental evaluation of ECAM is divided into four main parts:
1. Causal discovery performance evaluation
2. Causal effect estimation evaluation
3. Downstream task performance evaluation
4. Ablation studies

Each part requires different datasets, baseline models, and evaluation metrics.



## 2. Dataset Selection and Preparation

### 2.1 Synthetic Datasets

**Purpose**: To provide a controlled environment with known causal structures for evaluating causal discovery and effect estimation performance.

**Generation Method**:
1. **Graph Structure Generation**:
   - Erdős–Rényi (ER) random graphs: Use the `networkx` library, parameters `n=10-50` (number of nodes), `p=0.1-0.3` (edge probability)
   - Scale-Free (SF) graphs: Use the Barabási–Albert model from `networkx`, parameters `n=10-50`, `m=1-3` (number of edges to attach from a new node to existing nodes)
   - Ensure generated graphs are Directed Acyclic Graphs (DAGs)

2. **Data Generation**:
   - **Linear SCM**: $X_i = \sum_{j \in PA_i} w_{ji}X_j + U_i$, where $U_i \sim \mathcal{N}(0, \sigma^2)$
   - **Non-linear SCM**: $X_i = f_i(\sum_{j \in PA_i} w_{ji}X_j) + U_i$, where $f_i$ can be sigmoid, tanh, or polynomial functions
   - Weights $w_{ji}$ are uniformly sampled from $[-2.0, -0.5] \cup [0.5, 2.0]$ to avoid weak causal effects close to zero
   - Noise variance $\sigma^2$ is set to 0.1-0.5 to control the signal-to-noise ratio

3. **Dataset Scale**:
   - Training set: 5000 samples generated for each graph structure
   - Validation set: 1000 samples
   - Test set: 2000 samples

4. **Interventional Data**:
   - Perform interventions $do(X_i=x)$ on each node $X_i$, where $x$ is uniformly sampled from $[-2, 2]$
   - Generate 100 interventional samples for each node to evaluate causal effect estimation

**Code Implementation**:
```python
import numpy as np
import networkx as nx

def generate_random_dag(n, p, graph_type='er', m=2):
    """Generate a random DAG on nodes 0..n-1.

    Edges are oriented from the lower to the higher node index, so the
    result is acyclic by construction and no cycle removal is needed.
    """
    if graph_type == 'er':
        # Undirected ER graph; p is the edge probability
        G_undir = nx.gnp_random_graph(n=n, p=p)
    elif graph_type == 'sf':
        # Undirected Barabási–Albert graph; m edges per new node (p is unused)
        G_undir = nx.barabasi_albert_graph(n=n, m=m)
    else:
        raise ValueError(f"Unknown graph_type: {graph_type!r}")

    G = nx.DiGraph()
    G.add_nodes_from(range(n))
    # Orient every edge from the lower to the higher index
    G.add_edges_from((min(i, j), max(i, j)) for i, j in G_undir.edges())

    return G

def generate_weights(G):
    """Generate edge weights for the DAG"""
    weights = {}
    for (i, j) in G.edges():
        # Avoid weights close to zero
        if np.random.random() < 0.5:
            weights[(i, j)] = np.random.uniform(0.5, 2.0)
        else:
            weights[(i, j)] = np.random.uniform(-2.0, -0.5)
    return weights

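def generate_linear_scm_data(G, weights, n_samples, noise_scale=0.1):
    """Generate linear SCM data.

    noise_scale is the standard deviation of U_i (sigma, not sigma^2).
    Root nodes are drawn from N(0, 1).
    """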
    n = G.number_of_nodes()
    X = np.zeros((n_samples, n))
    
    # Topological sort ensures generation in causal order
    for node in nx.topological_sort(G):
        parents = list(G.predecessors(node))
        if not parents:  # Root node
            X[:, node] = np.random.normal(0, 1, n_samples)
        else:
            # Linear combination of parent nodes
            X[:, node] = sum(weights[(p, node)] * X[:, p] for p in parents)
            # Add noise
            X[:, node] += np.random.normal(0, noise_scale, n_samples)
    
    return X

def generate_intervention_data(G, weights, node, value, n_samples, noise_scale=0.1):
    """Generate interventional data do(X_node=value)"""
    n = G.number_of_nodes()
    X = np.zeros((n_samples, n))
    
    # Set intervention value
    X[:, node] = value
    
    # Topological sort
    for i in nx.topological_sort(G):
        if i == node:  # Skip intervened node
            continue
        
        parents = list(G.predecessors(i))
        if not parents:  # Root node
            X[:, i] = np.random.normal(0, 1, n_samples)
        else:
            # Linear combination of parent nodes. The intervened node keeps its
            # fixed value, so it must be included for the effect to propagate.
            X[:, i] = sum(weights[(p, i)] * X[:, p] for p in parents)
            # Add noise
            X[:, i] += np.random.normal(0, noise_scale, n_samples)
    
    return X
```
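The generators above cover only the linear SCM. The non-linear variant $X_i = f_i(\sum_{j \in PA_i} w_{ji}X_j) + U_i$ could be sampled along the same lines. The following is a minimal sketch, assuming the same `G` and `weights` conventions. The function name `generate_nonlinear_scm_data`, the `rng` argument, and the random per-node choice between tanh and sigmoid are illustrative assumptions rather than part of the original design.

```python
import numpy as np
import networkx as nx

def sigmoid(z):
    """Logistic function 1 / (1 + exp(-z))."""
    return 1.0 / (1.0 + np.exp(-z))

def generate_nonlinear_scm_data(G, weights, n_samples, noise_scale=0.1, rng=None):
    """Sample X_i = f_i(sum_j w_ji * X_j) + U_i in topological order.

    noise_scale is the standard deviation of U_i. Picking f_i at random
    from {tanh, sigmoid} for each node is an assumption of this sketch.
    """
    rng = np.random.default_rng() if rng is None else rng
    nonlinearities = [np.tanh, sigmoid]
    n = G.number_of_nodes()
    X = np.zeros((n_samples, n))

    for node in nx.topological_sort(G):
        parents = list(G.predecessors(node))
        if not parents:
            # Root nodes are pure noise, as in the linear generator
            X[:, node] = rng.normal(0, 1, n_samples)
            continue
        f = nonlinearities[rng.integers(len(nonlinearities))]
        pre_activation = sum(weights[(p, node)] * X[:, p] for p in parents)
        X[:, node] = f(pre_activation) + rng.normal(0, noise_scale, n_samples)

    return X
```

Passing a seeded `np.random.default_rng(seed)` as `rng` makes the sampled data reproducible across runs.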

### 2.2 Real-world Benchmark Datasets

#### 2.2.1 Causal Discovery Datasets

**Tübingen Cause-Effect Pairs**
- **Source**: [https://webdav.tuebingen.mpg.de/cause-effect/](https://webdav.tuebingen.mpg.de/cause-effect/)
- **Description**: Contains 108 pairs of variables, each with a true causal relationship label.
- **Preprocessing**:
  - Standardize all variables.
  - Split the dataset into training (70%), validation (10%), and test (20%) sets.
  - For multivariate data, extract relevant variable pairs.
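
The standardization and 70/10/20 split described above might look like the following sketch. It assumes each pair is available as an `(n_samples, 2)` array and that the split is made at the level of whole pairs. Both assumptions, along with the helper names `standardize` and `split_indices`, are illustrative.

```python
import numpy as np

def standardize(pair):
    """Scale each column of an (n_samples, 2) array to zero mean, unit variance."""
    pair = np.asarray(pair, dtype=float)
    std = pair.std(axis=0)
    std[std == 0] = 1.0  # Guard against constant columns
    return (pair - pair.mean(axis=0)) / std

def split_indices(n_items, fractions=(0.7, 0.1, 0.2), seed=0):
    """Shuffle item indices and cut them into train/validation/test parts."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_items)
    n_train = int(round(fractions[0] * n_items))
    n_val = int(round(fractions[1] * n_items))
    # The test part takes whatever remains after train and validation
    return order[:n_train], order[n_train:n_train + n_val], order[n_train + n_val:]
```

With 108 pairs and the default fractions, this yields 76 training, 11 validation, and 21 test pairs.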

#### 2.2.2 Downstream Task Datasets

**Natural Language Inference (NLI) Datasets**
- **GLUE Benchmark**:
  - **Source**: [https://gluebenchmark.com/](https://gluebenchmark.com/)
  - **Subsets**: MNLI, RTE, QNLI (with special attention to samples requiring causal reasoning).
  - **Preprocessing**: Use standard splits, extract subsets of samples that require causal reasoning.

- **CLUTRR**:
  - **Source**: [https://github.com/facebookresearch/clutrr](https://github.com/facebookresearch/clutrr)
  - **Description**: A dataset specifically designed to test relational and causal reasoning abilities.
  - **Preprocessing**: Use 2-4 hop relational reasoning tasks, standard train/validation/test splits.

**Visual Question Answering (VQA) Datasets**
- **VQA v2.0**:
  - **Source**: [https://visualqa.org/](https://visualqa.org/)
  - **Preprocessing**:
    - Extract subsets involving causal questions (e.g., "why...", "what caused...", "what if...").
    - Create a causal VQA subset using keyword filtering and manual verification.
    - Standard train/validation/test splits.

- **GQA**:
  - **Source**: [https://cs.stanford.edu/people/dorarad/gqa/](https://cs.stanford.edu/people/dorarad/gqa/)
  - **Preprocessing**: Similar to VQA, extract causally relevant questions.
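
The keyword-filtering step for VQA and GQA could start from a simple pattern match like the sketch below. The pattern covers only the three example phrases named above. The names `CAUSAL_PATTERN`, `is_causal_question`, and `filter_causal` are hypothetical, and the matches are only candidates that still need manual verification.

```python
import re

# Trigger phrases taken from the examples above; extend as needed.
CAUSAL_PATTERN = re.compile(r"\bwhy\b|\bwhat caused\b|\bwhat if\b", re.IGNORECASE)

def is_causal_question(question):
    """Return True if the question contains any causal trigger phrase."""
    return CAUSAL_PATTERN.search(question) is not None

def filter_causal(questions):
    """Keep only the candidate causal questions from an iterable of strings."""
    return [q for q in questions if is_causal_question(q)]
```

Keyword matching favors recall over precision, which is why the manual verification pass remains necessary.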



## 3. Detailed Experimental Design

### 3.1 Causal Discovery Experiments

**Objective**: To evaluate ECAM's ability to recover underlying causal structures.

**Experimental Setup**:
1. **Datasets**:
   - Synthetic data: Data generated from ER and SF graphs.
   - Real-world data: Tübingen Cause-Effect Pairs.

2. **Baseline Models**:
   - **PC Algorithm**: Implemented using the `causal-learn` library.
   - **GES Algorithm**: Implemented using the `causal-learn` library.
   - **NOTEARS**: Using the original author's implementation [https://github.com/xunzheng/notears](https://github.com/xunzheng/notears)
   - **Attn+Mask**: Standard attention + post-processing graph mask (self-implemented).
   - **CATT**: Causal Attention (implemented based on the original paper).

3. **Evaluation Metrics**:
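   - **SHD (Structural Hamming Distance)**: Number of edge insertions, deletions, and reversals needed to turn the predicted graph into the true graph (lower is better).
   - **SID (Structural Intervention Distance)**: Number of node pairs for which the predicted graph yields a wrong interventional distribution (lower is better).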
   - **F1 Score**: Harmonic mean of precision and recall for edge prediction.

4. **Experimental Procedure**:
   - For each graph type and size, generate 10 different random graph structures.
   - Generate datasets for each graph structure.
   - Train ECAM and baseline models.
   - Evaluate the difference between the recovered causal graph and the true graph.
   - Calculate average performance metrics and standard deviations.

5. **Hyperparameter Tuning**:
   - Use grid search or Bayesian optimization to tune key hyperparameters.
   - Key hyperparameters for ECAM: learning rate, regularization strength, graph learning weight, modularity function type.
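
SHD and the edge-level F1 score from the evaluation metrics above can be computed directly from edge sets. The sketch below assumes graphs are given as collections of `(source, target)` tuples, for example `G.edges()` from `networkx`, and counts a reversed edge as a single SHD error, which is one common convention. The function names are illustrative. SID is left out because it requires adjustment-set computations that do not fit in a short example.

```python
def shd(true_edges, pred_edges):
    """Count edge insertions, deletions, and reversals (a reversal counts once)."""
    true_edges, pred_edges = set(true_edges), set(pred_edges)
    extra = pred_edges - true_edges
    missing = true_edges - pred_edges
    # An extra edge (i, j) whose flip (j, i) is missing is a single reversal,
    # so it must not be counted as one insertion plus one deletion.
    reversed_edges = {(i, j) for (i, j) in extra if (j, i) in missing}
    return len(extra) + len(missing) - len(reversed_edges)

def edge_prf(true_edges, pred_edges):
    """Precision, recall, and F1 of the predicted directed edges."""
    true_edges, pred_edges = set(true_edges), set(pred_edges)
    tp = len(true_edges & pred_edges)
    precision = tp / len(pred_edges) if pred_edges else 0.0
    recall = tp / len(true_edges) if true_edges else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    return precision, recall, 2 * precision * recall / (precision + recall)
```

Per-graph values can then be aggregated over the 10 random structures with `statistics.mean` and `statistics.stdev`.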

### 3.2 Causal Effect Estimation Experiments

**Objective**: To evaluate ECAM's ability to predict the effects of interventions.

**Experimental Setup**:
1. **Dataset**:
   - Synthetic data: Use the same graph structures as in 3.1, but include interventional data.

2. **Baseline Models**:
   - **Standard Transformer**: No causal enhancement.
   - **IRM**: Invariant Risk Minimization.
   - **Attn+Mask** and **CATT**: Same as in 3.1.

3. **Evaluation Metrics**:
   - **PEHE (Precision in Estimation of Heterogeneous Effect)**: $\sqrt{\frac{1}{n}\sum_i(Y_i(1)-Y_i(0) - (\hat{Y}_i(1)-\hat{Y}_i(0)))^2}$
   - **ATE MSE**: Mean squared error between the estimated and true Average Treatment Effect.

4. **Experimental Procedure**:
   - Train models using observational data.
   - Predict interventional effects for samples in the test set.
   - Compare with true interventional effects.
   - Calculate evaluation metrics.
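
Both effect-estimation metrics reduce to a few array operations once true and predicted potential outcomes are available. The sketch below follows the PEHE formula given above. It reads "ATE MSE" as the squared ATE error of a single run, to be averaged over runs; that reading and the function names are assumptions.

```python
import numpy as np

def _effects(y1, y0):
    """Individual treatment effects Y(1) - Y(0) as a float array."""
    return np.asarray(y1, dtype=float) - np.asarray(y0, dtype=float)

def pehe(y1_true, y0_true, y1_pred, y0_pred):
    """Root mean squared error between true and predicted individual effects."""
    diff = _effects(y1_true, y0_true) - _effects(y1_pred, y0_pred)
    return float(np.sqrt(np.mean(diff ** 2)))

def ate_squared_error(y1_true, y0_true, y1_pred, y0_pred):
    """Squared difference between the true and predicted average effect."""
    ate_true = _effects(y1_true, y0_true).mean()
    ate_pred = _effects(y1_pred, y0_pred).mean()
    return float((ate_true - ate_pred) ** 2)
```

On synthetic data, the true potential outcomes can be obtained by sampling the same SCM under two different interventions.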

### 3.3 Downstream Task Experiments

**Objective**: To evaluate the effectiveness of ECAM in enhancing the causal reasoning capabilities of foundation models.

**Experimental Setup**:
1. **NLI Task**:
   - **Datasets**: GLUE (MNLI, RTE, QNLI) and CLUTRR.
   - **Base Model**: BERT-base (110M parameters).
   - **Variants**:
     - BERT (original)
     - BERT+IRM
     - BERT+Attn+Mask
     - BERT+CATT
     - BERT+ECAM (our method)
   - **Evaluation Metrics**: Accuracy, F1 score, GLUE average score.

2. **VQA Task**:
   - **Datasets**: VQA v2.0 and GQA (causal question subsets).
   - **Base Model**: ViT-B/16 (86M parameters).
   - **Variants**: Similar to the NLI task.
   - **Evaluation Metrics**: VQA score, accuracy.

3. **Experimental Procedure**:
   - For each base model, implement the corresponding attention variant.
   - Fine-tune on downstream task datasets.
   - Evaluate performance on the test set.
   - Pay special attention to subsets of samples requiring causal reasoning.

### 3.4 Ablation Studies

**Objective**: To analyze the contribution of each component of ECAM.

**Experimental Setup**:
1. **Variants**:
   - **Full ECAM**: Includes all components.
   - **No LCG Learning**: Replace the learned causal graph with an identity matrix.
   - **No Intervention Mechanism**: Disable interventional attention computation.
   - **No Counterfactual Mechanism**: Disable counterfactual attention computation.
   - **Binary Mask M(G)**: Replace continuous weights with a binary mask.
   - **Correlation-based Graph**: Use simple correlation thresholding instead of optimized LCG learning.

2. **Evaluation Tasks**:
   - Evaluate each variant on representative tasks from 3.1-3.3.
   - Calculate the percentage performance relative to the full ECAM.

3. **Experimental Procedure**:
   - Use the same training and evaluation pipeline for each variant.
   - Ensure fair comparison (same random seeds, initialization, etc.).
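
The relative performance used in the ablation comparison can be computed as in the sketch below. Inverting the ratio for error metrics such as SHD or PEHE, so that a worse variant always scores below 100, is an assumption of this example, as is the function name.

```python
def relative_performance(variant_score, full_score, higher_is_better=True):
    """Variant performance as a percentage of full ECAM (100 means on par)."""
    if higher_is_better:
        # Accuracy, F1: a lower variant score gives a value below 100
        return 100.0 * variant_score / full_score
    # Error metrics (SHD, PEHE): a higher variant error gives a value below 100
    return 100.0 * full_score / variant_score
```

Scores of exactly zero in the denominator need separate handling, for instance by reporting absolute differences instead.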



## 4. Implementation Details

### 4.1 ECAM Implementation

**Framework**: PyTorch

**Core Components**:
```python
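import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# GraphLearner, InterventionModule, CounterfactualModule, and the
# modulate_with_graph method are ECAM components not shown in this excerpt.
class ECAM(nn.Module):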
    def __init__(self, d_model, n_heads, dropout=0.1, graph_reg=0.01):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Standard attention components
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.wo = nn.Linear(d_model, d_model)
        
        # Causal graph learning component
        self.graph_learner = GraphLearner(d_model, graph_reg)
        
        # Intervention and counterfactual components
        self.intervention_module = InterventionModule(d_model)
        self.counterfactual_module = CounterfactualModule(d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None, intervention=None):
        batch_size = x.size(0)  # x has shape (batch, seq_len, d_model)

        # Project to multi-head queries, keys, and values
        q = self.wq(x).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        k = self.wk(x).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v = self.wv(x).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        
        # Learn local causal graph
        G = self.graph_learner(x)
        
        # Calculate attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # Apply causal graph modulation
        causal_mask = self.modulate_with_graph(G)
        scores = scores * causal_mask
        
        # If there is an intervention
        if intervention is not None:
            scores = self.intervention_module(scores, intervention, G)
        
        # Apply softmax
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        
        # Calculate output
        output = torch.matmul(attn, v)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.wo(output)
        
        return output, attn, G
```

**Key Hyperparameters**:
- Learning rate: 1e-4 to 5e-4
- Batch size: 16 to 64
- Causal graph regularization strength: 0.01 to 0.1
- Number of attention heads: 4 to 12
- Model dimension: 256 to 768

### 4.2 Baseline Model Implementation

- **Standard Transformer**: Use Hugging Face's implementation.
- **PC/GES**: Use the `causal-learn` library.
- **NOTEARS**: Use the original author's implementation.
- **CATT**: Implement based on the original paper.
- **IRM**: Use a public implementation [https://github.com/facebookresearch/InvariantRiskMinimization](https://github.com/facebookresearch/InvariantRiskMinimization)

## 5. Required Charts and Visualizations

### 5.1 Causal Discovery Performance Charts

1. **SHD/SID Comparison Bar Chart**:
   - X-axis: Different methods (PC, GES, NOTEARS, Attn+Mask, CATT, ECAM)
   - Y-axis: SHD and SID values (lower is better)
   - Grouped by: ER graphs and SF graphs
   - Error bars: Standard deviation

2. **F1 Score Comparison Chart**:
   - Similar to the above, but Y-axis is F1 score (higher is better)

3. **Performance Curve vs. Number of Nodes**:
   - X-axis: Number of nodes (10 to 50)
   - Y-axis: SHD or F1
   - Multiple lines: Different methods

4. **Learned Causal Graph Visualization**:
   - Compare the true graph with the ECAM-learned graph.
   - Use different colors to mark correct edges, missing edges, and extra edges.
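
The grouped bar charts with error bars described in this section could be produced with a helper like the sketch below. It assumes Matplotlib and per-method means and standard deviations stored in dictionaries keyed by graph type (for example `"ER"` and `"SF"`). The function name and argument layout are illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # Non-interactive backend so the script runs headless
import matplotlib.pyplot as plt
import numpy as np

def plot_grouped_bars(methods, means_by_group, stds_by_group, ylabel):
    """Grouped bar chart: one cluster per method, one bar per group."""
    groups = list(means_by_group)
    x = np.arange(len(methods))
    width = 0.8 / len(groups)
    fig, ax = plt.subplots(figsize=(8, 4))

    for k, group in enumerate(groups):
        # Center the bars of each cluster around the method's tick position
        offset = (k - (len(groups) - 1) / 2) * width
        ax.bar(x + offset, means_by_group[group], width,
               yerr=stds_by_group[group], capsize=3, label=group)

    ax.set_xticks(x)
    ax.set_xticklabels(methods)
    ax.set_ylabel(ylabel)
    ax.legend(title="Graph type")
    fig.tight_layout()
    return fig, ax
```

The returned figure can be written to disk with `fig.savefig(...)`. The same helper serves the SHD, SID, and F1 charts by changing `ylabel` and the input values.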

### 5.2 Causal Effect Estimation Charts

1. **PEHE/ATE MSE Comparison Bar Chart**:
   - X-axis: Different methods
   - Y-axis: PEHE and ATE MSE (lower is better)
   - Error bars: Standard deviation

2. **Interventional Effect Prediction Scatter Plot**:
   - X-axis: True interventional effect
   - Y-axis: Predicted interventional effect
   - Different colors: Different methods
   - Ideal case: Points fall on the diagonal line.

### 5.3 Downstream Task Performance Charts

1. **NLI Performance Comparison Bar Chart**:
   - X-axis: Different methods (BERT, BERT+IRM, etc.)
   - Y-axis: Accuracy/F1 score
   - Grouped by: Different datasets (MNLI, RTE, QNLI, CLUTRR)

2. **VQA Performance Comparison Bar Chart**:
   - Similar to the above, but for VQA tasks.

3. **Causal vs. Non-causal Question Performance Comparison**:
   - Compare model performance differences on causally relevant questions and non-causal questions.

### 5.4 Ablation Study Charts

1. **Ablation Study Bar Chart**:
   - X-axis: Different ECAM variants
   - Y-axis: Relative performance percentage
   - Grouped by: Different tasks (causal discovery, effect estimation, downstream tasks)

2. **Attention Visualization**:
   - Heatmaps showing attention weights of different variants.
   - Pay special attention to attention changes before and after intervention.

### 5.5 Charts Requiring Data Collection Through Experiments

The following charts require data collection by running experiments:
- All performance comparison bar charts and curves in 5.1-5.3
- Interventional effect prediction scatter plot
- Ablation study charts
- Attention visualization

Charts that can be designed beforehand:
- Causal graph structure examples
- Method architecture diagram
- Experimental flowchart

## 6. Experimental Resource Requirements

### 6.1 Computational Resources

- **GPU**: At least one NVIDIA V100 or A100 GPU is required.
- **Memory**: At least 32GB RAM.
- **Storage**: Approximately 100GB for datasets and model checkpoints.
- **Estimated Training Time**:
  - Synthetic data experiments: Approximately 24-48 hours.
  - Downstream task experiments: Approximately 72-120 hours (depending on model size and dataset).

### 6.2 Software Dependencies

- Python 3.8+
- PyTorch 1.10+
- Hugging Face Transformers
- NetworkX
- Causal-learn
- Scikit-learn
- Matplotlib/Seaborn (for visualization)
- Pandas
- NumPy

## 7. Experimental Timeline

1. **Data Preparation** (1-2 weeks):
   - Synthetic data generation.
   - Real-world dataset acquisition and preprocessing.
   - Creation of train/validation/test splits.

2. **Model Implementation** (2-3 weeks):
   - ECAM core component implementation.
   - Baseline model implementation/adaptation.
   - Unit testing and validation.

3. **Causal Discovery Experiments** (1-2 weeks):
   - Run all models.
   - Collect results and analyze.

4. **Causal Effect Estimation Experiments** (1-2 weeks):
   - Implement intervention mechanisms.
   - Run experiments and analyze.

5. **Downstream Task Experiments** (2-3 weeks):
   - Integrate ECAM into base models.
   - Fine-tune and evaluate.
   - Result analysis.

6. **Ablation Studies** (1 week):
   - Implement variants.
   - Run comparative experiments.

7. **Result Visualization and Analysis** (1-2 weeks):
   - Create all charts.
   - Statistical analysis.
   - Write result interpretations.

Total: Approximately 9-15 weeks of experimental time.
